from torch.utils.data import Dataset
from src.data.base_dataset import BaseDataset


class TurbineDataset(BaseDataset):
    """Dataset that serves fixed-length input windows cut from ``self.data``.

    Construction is delegated entirely to ``BaseDataset``; this subclass only
    defines how one sample is sliced out in :meth:`getdata`.
    """

    def __init__(
        self,
        data,
        space_graph,
        seq,
        pred_len,
        target_turbine,
        target_feature,
        base_feature,
        logfile=None,
        single_point=True,
    ):
        """Forward the configuration to ``BaseDataset``.

        All arguments except ``single_point`` are passed positionally, in the
        same order, to ``BaseDataset.__init__``. ``single_point`` is accepted
        but is neither forwarded nor stored by this class.
        """
        super().__init__(
            data,
            space_graph,
            seq,
            pred_len,
            target_turbine,
            target_feature,
            base_feature,
            logfile,
        )

    def getdata(self, index):
        """Return the input window that starts at ``index``.

        The window spans ``self.seq`` consecutive positions along the first
        axis of ``self.data``; the two following axes are taken in full.
        Only the input window is returned -- no prediction/target slice.

        NOTE(review): ``self.data`` and ``self.seq`` are not assigned in this
        class; presumably ``BaseDataset.__init__`` sets them -- confirm.
        """
        window = slice(index, index + self.seq)
        return self.data[window, :, :]